{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Variational Quantum Classifier Feature Map Comparison\n",
    "\n",
    "Both the first-order and second-order expansion feature maps provided by `qiskit.circuit.library` use $n$ qubits to encode $n$-dim datapoints. However, raw feature vectors can also be directly used in `VQC` circuit constructions, requiring only $\\log_2(n)$ qubits to encode $n$-dim datapoints.\n",
    "\n",
    "### Experiment\n",
    "Below we compare the classification performance of `VQC` on the [Wine dataset](https://scikit-learn.org/stable/datasets/index.html#wine-dataset) using the `RawFeatureVector` and `ZZFeatureMap` feature maps. As you'll see, the former leads to much better testing accuracy while using only $2$ qubits and taking less time. Note that very few training data points are being used.\n",
    "`VQC` uses [SamplerQNN](https://qiskit-community.github.io/qiskit-machine-learning/stubs/qiskit_machine_learning.neural_networks.SamplerQNN.html#samplerqnn) by default, which is why you don't need to specify a primitive."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sklearn.datasets import load_wine\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.preprocessing import MinMaxScaler, StandardScaler\n",
    "\n",
    "from qiskit.circuit.library import ZZFeatureMap, TwoLocal\n",
    "\n",
    "from qiskit_machine_learning.algorithms.classifiers import VQC\n",
    "from qiskit_machine_learning.circuit.library import RawFeatureVector\n",
    "from qiskit_algorithms.optimizers import SPSA\n",
    "from qiskit_algorithms.utils import algorithm_globals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Helper that loads the scikit-learn Wine dataset, standardizes it, reduces it with PCA,\n",
    "# rescales it to [-1, 1] and optionally plots the first two components.\n",
    "\n",
    "def wine(training_size, test_size, n_features, plot_data=False):\n",
    "    \"\"\"Return a PCA-reduced, rescaled train/test split of the Wine dataset.\n",
    "\n",
    "    Args:\n",
    "        training_size: maximum number of training samples to keep.\n",
    "        test_size: size of the held-out split, forwarded unchanged to ``train_test_split``.\n",
    "        n_features: number of PCA components kept per sample.\n",
    "        plot_data: if True, scatter-plot the first two components of the training\n",
    "            samples (this indexes column 1, so it needs ``n_features >= 2``).\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(sample_train, label_train, sample_test, label_test)``.\n",
    "\n",
    "    Raises:\n",
    "        NameError: if ``plot_data`` is True and matplotlib cannot be imported.\n",
    "    \"\"\"\n",
    "    data, target = load_wine(return_X_y=True)\n",
    "    sample_train, sample_test, label_train, label_test = train_test_split(\n",
    "        data, target, test_size=test_size, random_state=42\n",
    "    )\n",
    "\n",
    "    # Standardize with statistics taken from the training split only.\n",
    "    std_scale = StandardScaler().fit(sample_train)\n",
    "    sample_train = std_scale.transform(sample_train)\n",
    "    sample_test = std_scale.transform(sample_test)\n",
    "\n",
    "    # Reduce to n_features dimensions; the projection is also fitted on the training split only.\n",
    "    pca = PCA(n_components=n_features).fit(sample_train)\n",
    "    sample_train = pca.transform(sample_train)\n",
    "    sample_test = pca.transform(sample_test)\n",
    "\n",
    "    # NOTE(review): the [-1, 1] range is fitted on train and test samples together, so the\n",
    "    # test data influences the scaling. Kept as-is so existing results stay comparable.\n",
    "    samples = np.append(sample_train, sample_test, axis=0)\n",
    "    minmax_scale = MinMaxScaler(feature_range=(-1, 1)).fit(samples)\n",
    "    sample_train = minmax_scale.transform(sample_train)\n",
    "    sample_test = minmax_scale.transform(sample_test)\n",
    "\n",
    "    # Only the training split is truncated. The test split is already sized by `test_size`\n",
    "    # in train_test_split; truncating it with `training_size` silently dropped test samples\n",
    "    # whenever the test split was larger than the training subset.\n",
    "    sample_train = sample_train[:training_size, :]\n",
    "    label_train = label_train[:training_size]\n",
    "\n",
    "    if plot_data:\n",
    "        try:\n",
    "            import matplotlib.pyplot as plt\n",
    "        except ImportError as exc:\n",
    "            raise NameError('Matplotlib not installed. Please install it before plotting') from exc\n",
    "        # One scatter call per label value 0, 1 and 2, so each class gets its own color.\n",
    "        for k in range(0, 3):\n",
    "            class_mask = label_train == k\n",
    "            plt.scatter(sample_train[class_mask, 0], sample_train[class_mask, 1])\n",
    "\n",
    "        plt.title(\"PCA dim. reduced Wine dataset\")\n",
    "        plt.show()\n",
    "\n",
    "    return sample_train, label_train, sample_test, label_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of PCA components kept per sample; reused below as the feature-map dimension.\n",
    "n_features = 4\n",
    "\n",
    "# Deliberately tiny dataset: at most 20 training points and a 10-point test split.\n",
    "train_data, train_labels, test_data, test_labels = wine(\n",
    "    training_size=20,\n",
    "    test_size=10,\n",
    "    n_features=n_features,\n",
    "    plot_data=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's try `RawFeatureVector` first:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "     »\n",
      "q_0: »\n",
      "     »\n",
      "q_1: »\n",
      "     »\n",
      "«     ┌──────────────────────────────────────────────────────────────────────────────────────────────────┐\n",
      "«q_0: ┤0                                                                                                 ├\n",
      "«     │  TwoLocal(θ[0],θ[1],θ[2],θ[3],θ[4],θ[5],θ[6],θ[7],θ[8],θ[9],θ[10],θ[11],θ[12],θ[13],θ[14],θ[15]) │\n",
      "«q_1: ┤1                                                                                                 ├\n",
      "«     └──────────────────────────────────────────────────────────────────────────────────────────────────┘\n",
      "Results with RawFeatureVector: train set accuracy -> 0.85, test set accuracy -> 0.7.\n"
     ]
    }
   ],
   "source": [
    "feature_map = RawFeatureVector(feature_dimension=n_features)\n",
    "\n",
    "ansatz = TwoLocal(feature_map.num_qubits, ['ry', 'rz'], 'cz', reps=3)\n",
    "print(ansatz) # this approach requires log2(n_features) = 2 qubits\n",
    "vqc = VQC(\n",
    "    feature_map=feature_map,\n",
    "    ansatz=ansatz,\n",
    "    loss=\"cross_entropy\",\n",
    "    optimizer=SPSA(maxiter=100)\n",
    ")\n",
    "\n",
    "vqc.fit(train_data, train_labels)\n",
    "\n",
    "print(f'Results with RawFeatureVector: train set accuracy -> {vqc.score(train_data, train_labels)}, test set accuracy -> {vqc.score(test_data, test_labels)}.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's try `ZZFeatureMap`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "     »\n",
      "q_0: »\n",
      "     »\n",
      "q_1: »\n",
      "     »\n",
      "q_2: »\n",
      "     »\n",
      "q_3: »\n",
      "     »\n",
      "«     ┌──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐\n",
      "«q_0: ┤0                                                                                                                                                                                                 ├\n",
      "«     │                                                                                                                                                                                                  │\n",
      "«q_1: ┤1                                                                                                                                                                                                 ├\n",
      "«     │  TwoLocal(θ[0],θ[1],θ[2],θ[3],θ[4],θ[5],θ[6],θ[7],θ[8],θ[9],θ[10],θ[11],θ[12],θ[13],θ[14],θ[15],θ[16],θ[17],θ[18],θ[19],θ[20],θ[21],θ[22],θ[23],θ[24],θ[25],θ[26],θ[27],θ[28],θ[29],θ[30],θ[31]) │\n",
      "«q_2: ┤2                                                                                                                                                                                                 ├\n",
      "«     │                                                                                                                                                                                                  │\n",
      "«q_3: ┤3                                                                                                                                                                                                 ├\n",
      "«     └──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘\n",
      "Results with ZZFeatureMap: train set accuracy -> 0.8, test set accuracy -> 0.3.\n"
     ]
    }
   ],
   "source": [
    "feature_map = ZZFeatureMap(n_features)\n",
    "\n",
    "ansatz = TwoLocal(feature_map.num_qubits, ['ry', 'rz'], 'cz', reps=3)\n",
    "print(ansatz) # this approach requires n_features = 4 qubits\n",
    "\n",
    "vqc = VQC(\n",
    "    feature_map=feature_map,\n",
    "    ansatz=ansatz,\n",
    "    loss=\"cross_entropy\",\n",
    "    optimizer=SPSA(maxiter=100)\n",
    ")\n",
    "\n",
    "vqc.fit(train_data, train_labels)\n",
    "\n",
    "print(f'Results with ZZFeatureMap: train set accuracy -> {vqc.score(train_data, train_labels)}, test set accuracy -> {vqc.score(test_data, test_labels)}.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "qiskit==1.1.0\n",
      "qiskit-aer==0.14.2\n",
      "qiskit-algorithms==0.3.0\n",
      "qiskit-ibm-runtime==0.25.0\n",
      "qiskit-machine-learning==0.7.2\n",
      "qiskit-nature==0.7.2\n",
      "qiskit-nature-pyscf==0.4.0\n",
      "qiskit-qasm3-import==0.5.0\n",
      "qiskit-transpiler-service==0.4.5\n"
     ]
    }
   ],
   "source": [
    "# Shell command: list the installed packages whose `pip freeze` line contains 'qiskit',\n",
    "# recording the library versions this notebook was run with.\n",
    "! pip freeze | grep qiskit"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
